{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e45f9b60-cd6b-4c15-958f-1feca5438128",
   "metadata": {},
   "source": [
    "# SQL Query Engine with LlamaIndex + DuckDB\n",
    "\n",
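    "This guide showcases LlamaIndex's core text-to-SQL capabilities, using DuckDB as the database.\n",
    "\n",
    "We walk through two LlamaIndex query engines: `NLSQLTableQueryEngine`, which translates natural-language questions into SQL over a given set of tables, and `SQLTableRetrieverQueryEngine`, which first retrieves the relevant table schemas from an index and is suited to databases with many tables."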
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89cff478-98d8-464a-ac61-c04bb521ebb4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install duckdb duckdb-engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fbd7317b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95d4acf1-910c-4788-b2a0-6ea80b360470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index import SQLDatabase\n",
    "from llama_index.indices.struct_store import (\n",
    "    NLSQLTableQueryEngine,\n",
    "    SQLTableRetrieverQueryEngine,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "53bf3d21-1321-4578-8be5-e752eebc879e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d57d7f32-92cc-485d-9cc1-884268d56377",
   "metadata": {},
   "source": [
    "## Basic Text-to-SQL with our `NLSQLTableQueryEngine` \n",
    "\n",
    "In this first example, we populate a SQL database with a few test data points and then query it in natural language using text-to-SQL."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "461438c8-302d-45c5-8e69-16ad604686d1",
   "metadata": {},
   "source": [
    "### Create Database Schema + Test Data\n",
    "\n",
    "We use SQLAlchemy, a popular SQL toolkit for Python, to connect to DuckDB (through the `duckdb-engine` dialect installed above) and create an empty `city_stats` table. We then populate it with some test data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    MetaData,\n",
    "    Table,\n",
    "    Column,\n",
    "    String,\n",
    "    Integer,\n",
    "    select,\n",
    "    column,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ea24f794-f10b-42e6-922d-9258b7167405",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "engine = create_engine(\"duckdb:///:memory:\")\n",
    "# uncomment to make this work with MotherDuck\n",
    "# engine = create_engine(\"duckdb:///md:llama-index\")\n",
    "metadata_obj = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b4154b29-7e23-4c26-a507-370a66186ae7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4768bcb4-c40e-4d5d-8d70-7cb3228b50ab",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['city_stats'])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# print tables\n",
    "metadata_obj.tables.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c0eb518-5da3-4215-8280-0776d07806a0",
   "metadata": {},
   "source": [
    "Next, we insert some test data into the `city_stats` table and read it back to confirm the rows were written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d15192b6-99f9-4f72-b637-82e885ea057f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import insert\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\"city_name\": \"Chicago\", \"population\": 2679000, \"country\": \"United States\"},\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(city_stats_table).values(**row)\n",
    "    with engine.connect() as connection:\n",
    "        cursor = connection.execute(stmt)\n",
    "        connection.commit()"
   ]
  },
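  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a01",
   "metadata": {},
   "source": [
    "The loop above opens a new connection and commits once per row. A common alternative is to insert all rows in a single transaction with one `executemany` call. The next cell is a minimal sketch of that pattern, assuming Python's built-in `sqlite3` module as a stand-in for DuckDB so it runs without extra dependencies; the `demo_*` names are hypothetical and the cell does not touch the DuckDB `engine`.\n",
    "\n",
    "With SQLAlchemy, the same idea is to pass a list of parameter dictionaries, e.g. `connection.execute(insert(city_stats_table), rows)` inside a `with engine.begin() as connection:` block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "# Same test rows as above, keyed by column name\n",
    "demo_rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\"city_name\": \"Chicago\", \"population\": 2679000, \"country\": \"United States\"},\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "]\n",
    "\n",
    "# In-memory SQLite database, used here only as a stand-in for DuckDB\n",
    "demo_conn = sqlite3.connect(\":memory:\")\n",
    "demo_conn.execute(\n",
    "    \"CREATE TABLE city_stats (\"\n",
    "    \"city_name VARCHAR(16) PRIMARY KEY, \"\n",
    "    \"population INTEGER, \"\n",
    "    \"country VARCHAR(16) NOT NULL)\"\n",
    ")\n",
    "\n",
    "# Using the connection as a context manager commits on success and rolls back\n",
    "# on error, so all rows are written in one transaction\n",
    "with demo_conn:\n",
    "    demo_conn.executemany(\n",
    "        \"INSERT INTO city_stats VALUES (:city_name, :population, :country)\",\n",
    "        demo_rows,\n",
    "    )\n",
    "\n",
    "demo_result = demo_conn.execute(\"SELECT * FROM city_stats\").fetchall()\n",
    "print(demo_result)"
   ]
  },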
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bfc2e4a4-e11d-4d8f-bf1f-7f777a1dc6e2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]\n"
     ]
    }
   ],
   "source": [
    "with engine.connect() as connection:\n",
    "    cursor = connection.exec_driver_sql(\"SELECT * FROM city_stats\")\n",
    "    print(cursor.fetchall())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b553eac-1c51-48bf-bf00-918f206770f4",
   "metadata": {},
   "source": [
    "### Create SQLDatabase Object\n",
    "\n",
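    "We now wrap the engine in LlamaIndex's `SQLDatabase` abstraction (a light wrapper around SQLAlchemy), restricting it to the `city_stats` table via `include_tables`. The `DuckDBEngineWarning` about index reflection printed below comes from `duckdb-engine` and does not affect this example."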
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7d3c0312-b9ab-4f59-83ee-94399d620cae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index import SQLDatabase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "94c83270-ceaf-4084-b142-ca8b3a5bc9ca",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/duckdb_engine/__init__.py:162: DuckDBEngineWarning: duckdb-engine doesn't yet support reflection on indices\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "051a171f-8c97-40ed-ae17-4e3fa3785487",
   "metadata": {},
   "source": [
    "### Query the Database"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91139712-f232-47e1-9683-cbbd49cd331b",
   "metadata": {},
   "source": [
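    "Here we demonstrate `NLSQLTableQueryEngine`, which performs text-to-SQL: it asks the LLM to translate a natural-language question into a SQL query over the tables in the `SQLDatabase`, executes that query, and synthesizes a natural-language answer from the result.\n",
    "\n",
    "1. We construct an `NLSQLTableQueryEngine` and pass in our `SQLDatabase` object.\n",
    "2. We run queries against the query engine."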
   ]
  },
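  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a03",
   "metadata": {},
   "source": [
    "Conceptually, the engine runs three steps: turn the question into SQL, execute the SQL, and turn the returned rows into an answer. The `sql_query` and `result` keys in `response.metadata` further below expose the intermediate outputs. The next cell mimics that flow under stated assumptions: a hypothetical `fake_text_to_sql` stub replaces the LLM call (it always returns the same query), and an in-memory `sqlite3` database stands in for DuckDB. It illustrates the idea only and is not LlamaIndex's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "\n",
    "def fake_text_to_sql(question, table_desc):\n",
    "    # Hypothetical stand-in for the LLM call. A real engine would prompt an LLM\n",
    "    # with the question and the table description; this stub ignores both.\n",
    "    return (\n",
    "        \"SELECT city_name, population FROM city_stats \"\n",
    "        \"ORDER BY population DESC LIMIT 1\"\n",
    "    )\n",
    "\n",
    "\n",
    "def toy_text_to_sql_query(conn, question, table_desc):\n",
    "    # Step 1: translate the question into SQL\n",
    "    sql = fake_text_to_sql(question, table_desc)\n",
    "    # Step 2: execute the SQL against the database\n",
    "    result = conn.execute(sql).fetchall()\n",
    "    # Step 3: build an answer from the rows (a real engine asks the LLM to phrase it)\n",
    "    answer = f\"Rows returned for {question!r}: {result}\"\n",
    "    return {\"response\": answer, \"metadata\": {\"result\": result, \"sql_query\": sql}}\n",
    "\n",
    "\n",
    "toy_conn = sqlite3.connect(\":memory:\")\n",
    "toy_conn.execute(\n",
    "    \"CREATE TABLE city_stats (city_name VARCHAR(16), population INTEGER, country VARCHAR(16))\"\n",
    ")\n",
    "toy_conn.executemany(\n",
    "    \"INSERT INTO city_stats VALUES (?, ?, ?)\",\n",
    "    [\n",
    "        (\"Toronto\", 2930000, \"Canada\"),\n",
    "        (\"Tokyo\", 13960000, \"Japan\"),\n",
    "        (\"Chicago\", 2679000, \"United States\"),\n",
    "        (\"Seoul\", 9776000, \"South Korea\"),\n",
    "    ],\n",
    ")\n",
    "\n",
    "toy_out = toy_text_to_sql_query(\n",
    "    toy_conn,\n",
    "    \"Which city has the highest population?\",\n",
    "    \"Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR)\",\n",
    ")\n",
    "print(toy_out[\"response\"])\n",
    "print(toy_out[\"metadata\"])"
   ]
  },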
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cdb58997-2258-4305-98a1-4a7178c30210",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "query_engine = NLSQLTableQueryEngine(sql_database)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "eabededd-3c17-45b7-aabc-06a2457bc3cb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .\n",
      "> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/langchain/sql_database.py:238: UserWarning: This method is deprecated - please use `get_usable_table_names`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 332 tokens\n",
      "> [query] Total LLM token usage: 332 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n",
      "> [query] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "response = query_engine.query(\"Which city has the highest population?\")"
   ]
  },
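  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a05",
   "metadata": {},
   "source": [
    "The `Table desc str` log line above shows the schema context that the engine places in the text-to-SQL prompt. The next cell assembles a string of the same shape from a list of `(column name, type)` pairs. `describe_table` is a hypothetical helper written only to match the logged format; it is not LlamaIndex's code, which derives the description from the database schema."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def describe_table(table_name, columns, foreign_keys=()):\n",
    "    # columns: list of (column_name, sql_type) pairs\n",
    "    # foreign_keys: iterable of strings describing foreign-key relationships\n",
    "    col_str = \", \".join(f\"{name} ({sql_type})\" for name, sql_type in columns)\n",
    "    fk_str = \", \".join(foreign_keys)\n",
    "    return f\"Table '{table_name}' has columns: {col_str} and foreign keys: {fk_str}.\"\n",
    "\n",
    "\n",
    "city_desc = describe_table(\n",
    "    \"city_stats\",\n",
    "    [(\"city_name\", \"VARCHAR\"), (\"population\", \"INTEGER\"), (\"country\", \"VARCHAR\")],\n",
    ")\n",
    "print(city_desc)"
   ]
  },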
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "25c11645-56bd-433a-85f4-420413f8970d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Tokyo has the highest population, with 13,960,000 people.'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3f72abc6-54d7-4f85-abf8-32978d94f558",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'result': [('Tokyo', 13960000)],\n",
       " 'sql_query': 'SELECT city_name, population \\nFROM city_stats \\nORDER BY population DESC \\nLIMIT 1;'}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response.metadata"
   ]
  },
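  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a07",
   "metadata": {},
   "source": [
    "Because `response.metadata` records both the generated SQL and its result, you can re-run the query yourself to check the answer. In this notebook, that could be as simple as passing `response.metadata[\"sql_query\"]` to `connection.exec_driver_sql(...)`, as done earlier for `SELECT * FROM city_stats`. The next cell is a self-contained version of the check: it uses an in-memory `sqlite3` database as a stand-in for DuckDB and copies the recorded query and result from the output above, with the query joined onto one line."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "# Copied from response.metadata above; the query is joined onto one line\n",
    "recorded_sql = (\n",
    "    \"SELECT city_name, population FROM city_stats \"\n",
    "    \"ORDER BY population DESC LIMIT 1;\"\n",
    ")\n",
    "recorded_result = [(\"Tokyo\", 13960000)]\n",
    "\n",
    "# Rebuild the test table in an in-memory SQLite stand-in\n",
    "check_conn = sqlite3.connect(\":memory:\")\n",
    "check_conn.execute(\n",
    "    \"CREATE TABLE city_stats (city_name VARCHAR(16), population INTEGER, country VARCHAR(16))\"\n",
    ")\n",
    "check_conn.executemany(\n",
    "    \"INSERT INTO city_stats VALUES (?, ?, ?)\",\n",
    "    [\n",
    "        (\"Toronto\", 2930000, \"Canada\"),\n",
    "        (\"Tokyo\", 13960000, \"Japan\"),\n",
    "        (\"Chicago\", 2679000, \"United States\"),\n",
    "        (\"Seoul\", 9776000, \"South Korea\"),\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Re-execute the generated SQL and compare with what the engine reported\n",
    "check_rows = check_conn.execute(recorded_sql).fetchall()\n",
    "print(check_rows, check_rows == recorded_result)"
   ]
  },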
  {
   "cell_type": "markdown",
   "id": "a97908ac-0ace-49a2-8898-863c19213c02",
   "metadata": {},
   "source": [
    "## Advanced Text-to-SQL with our `SQLTableRetrieverQueryEngine` \n",
    "\n",
    "In this section, we tackle the case where your database has many tables and putting every table schema into the prompt may overflow the text-to-SQL prompt's context window.\n",
    "\n",
    "We first index the table schemas with an `ObjectIndex`, and then use `SQLTableRetrieverQueryEngine` on top of it: at query time, it retrieves only the schemas most relevant to the question and includes just those in the prompt."
   ]
  },
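  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a09",
   "metadata": {},
   "source": [
    "To make the overflow concern concrete, the next cell builds a one-line description for `city_stats` and for each of the 100 dummy tables created below, then compares the approximate size of all schemas together with that of a single table. The description format mirrors the `Table desc str` log line from the first section, and the four-characters-per-token rule is a rough heuristic (an assumption), not an exact tokenizer count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def table_desc(table_name, columns):\n",
    "    # Mirrors the shape of the 'Table desc str' log line (no foreign keys here)\n",
    "    col_str = \", \".join(f\"{name} ({sql_type})\" for name, sql_type in columns)\n",
    "    return f\"Table '{table_name}' has columns: {col_str} and foreign keys: .\"\n",
    "\n",
    "\n",
    "def approx_tokens(text):\n",
    "    # Rough heuristic (assumption): about four characters per token\n",
    "    return len(text) // 4\n",
    "\n",
    "\n",
    "overflow_descs = [\n",
    "    table_desc(\n",
    "        \"city_stats\",\n",
    "        [(\"city_name\", \"VARCHAR\"), (\"population\", \"INTEGER\"), (\"country\", \"VARCHAR\")],\n",
    "    )\n",
    "]\n",
    "# Same naming scheme as the 100 dummy tables created below\n",
    "for i in range(100):\n",
    "    overflow_descs.append(\n",
    "        table_desc(\n",
    "            f\"tmp_table_{i}\",\n",
    "            [\n",
    "                (f\"tmp_field_{i}_1\", \"VARCHAR\"),\n",
    "                (f\"tmp_field_{i}_2\", \"INTEGER\"),\n",
    "                (f\"tmp_field_{i}_3\", \"VARCHAR\"),\n",
    "            ],\n",
    "        )\n",
    "    )\n",
    "\n",
    "all_schemas_tokens = approx_tokens(\" \".join(overflow_descs))\n",
    "one_schema_tokens = approx_tokens(overflow_descs[0])\n",
    "print(f\"{len(overflow_descs)} tables: ~{all_schemas_tokens} tokens\")\n",
    "print(f\"city_stats only: ~{one_schema_tokens} tokens\")"
   ]
  },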
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "25c89e3a-eca0-41f1-b4ed-d76c2599b74a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "engine = create_engine(\"duckdb:///:memory:\")\n",
    "# uncomment to make this work with MotherDuck\n",
    "# engine = create_engine(\"duckdb:///md:llama-index\")\n",
    "metadata_obj = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "1840102f-d29f-4f5e-81b6-ddd339a8243e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "all_table_names = [\"city_stats\"]\n",
    "# create a ton of dummy tables\n",
    "n = 100\n",
    "for i in range(n):\n",
    "    tmp_table_name = f\"tmp_table_{i}\"\n",
    "    tmp_table = Table(\n",
    "        tmp_table_name,\n",
    "        metadata_obj,\n",
    "        Column(f\"tmp_field_{i}_1\", String(16), primary_key=True),\n",
    "        Column(f\"tmp_field_{i}_2\", Integer),\n",
    "        Column(f\"tmp_field_{i}_3\", String(16), nullable=False),\n",
    "    )\n",
    "    all_table_names.append(f\"tmp_table_{i}\")\n",
    "\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "1680238b-3b4d-41f2-8a75-cb3fbca9fff6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# insert dummy data\n",
    "from sqlalchemy import insert\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\"city_name\": \"Chicago\", \"population\": 2679000, \"country\": \"United States\"},\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(city_stats_table).values(**row)\n",
    "    with engine.connect() as connection:\n",
    "        cursor = connection.execute(stmt)\n",
    "        connection.commit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "e21f5889-47a7-489f-830a-3dd1c3347e53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "534aff98-3b86-4f85-8386-2d7f22c0d91d",
   "metadata": {},
   "source": [
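    "### Construct Object Index\n",
    "\n",
    "We wrap each table name in a `SQLTableSchema` object. `SQLTableNodeMapping` converts each schema object into a node, and `ObjectIndex.from_objects` indexes those nodes in a `VectorStoreIndex`, so that table schemas can later be retrieved by embedding similarity to the question."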
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "5a6a30b9-9b80-4219-abf9-9cb8b8376ce6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine\n",
    "from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema\n",
    "from llama_index import VectorStoreIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "b40afb95-4892-4348-8222-571fbdb5f21c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 6343 tokens\n",
      "> [build_index_from_nodes] Total embedding token usage: 6343 tokens\n"
     ]
    }
   ],
   "source": [
    "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
    "\n",
    "table_schema_objs = []\n",
    "for table_name in all_table_names:\n",
    "    table_schema_objs.append(SQLTableSchema(table_name=table_name))\n",
    "\n",
    "obj_index = ObjectIndex.from_objects(\n",
    "    table_schema_objs,\n",
    "    table_node_mapping,\n",
    "    VectorStoreIndex,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "862801a2-8f31-4c97-b3e2-f18bc679d774",
   "metadata": {
    "tags": []
   },
   "source": [
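    "### Query Index with `SQLTableRetrieverQueryEngine`\n",
    "\n",
    "We pass the `SQLDatabase` and a retriever over the object index to `SQLTableRetrieverQueryEngine`. With `similarity_top_k=1`, only the single most relevant table schema is included in the text-to-SQL prompt."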
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "ec5038cf-7816-4cde-af5d-94ad5c552c5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "query_engine = SQLTableRetrieverQueryEngine(\n",
    "    sql_database,\n",
    "    obj_index.as_retriever(similarity_top_k=1),\n",
    ")"
   ]
  },
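  {
   "cell_type": "markdown",
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a0b",
   "metadata": {},
   "source": [
    "With `similarity_top_k=1`, the retriever returns only the single table schema judged most similar to the question. The next cell imitates that retrieval step with plain word overlap standing in for embedding similarity; this is a deliberate simplification, since the `ObjectIndex` built above ranks schemas with a `VectorStoreIndex`. `retrieve_tables` is a hypothetical helper, and the descriptions reuse the format of the `Table desc str` log line."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e91a4-3b5d-4f60-9a1e-0d4b8c6f2a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def words(text):\n",
    "    # Lowercase alphabetic tokens; digits and underscores act as separators\n",
    "    return set(re.findall(\"[a-z]+\", text.lower()))\n",
    "\n",
    "\n",
    "def retrieve_tables(question, table_descs, top_k=1):\n",
    "    # Score each table by the number of distinct words its description shares\n",
    "    # with the question; sorted() is stable, so ties keep the original order\n",
    "    question_words = words(question)\n",
    "    ranked = sorted(\n",
    "        table_descs,\n",
    "        key=lambda name: len(question_words & words(table_descs[name])),\n",
    "        reverse=True,\n",
    "    )\n",
    "    return ranked[:top_k]\n",
    "\n",
    "\n",
    "schema_descs = {\n",
    "    \"city_stats\": (\n",
    "        \"Table 'city_stats' has columns: city_name (VARCHAR), \"\n",
    "        \"population (INTEGER), country (VARCHAR) and foreign keys: .\"\n",
    "    )\n",
    "}\n",
    "for i in range(100):\n",
    "    schema_descs[f\"tmp_table_{i}\"] = (\n",
    "        f\"Table 'tmp_table_{i}' has columns: tmp_field_{i}_1 (VARCHAR), \"\n",
    "        f\"tmp_field_{i}_2 (INTEGER), tmp_field_{i}_3 (VARCHAR) and foreign keys: .\"\n",
    "    )\n",
    "\n",
    "top_tables = retrieve_tables(\"Which city has the highest population?\", schema_descs, top_k=1)\n",
    "print(top_tables)"
   ]
  },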
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "dbad6208-dc80-4f9d-82ae-770d20e8b59d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens\n",
      "> [retrieve] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 7 tokens\n",
      "> [retrieve] Total embedding token usage: 7 tokens\n",
      "INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .\n",
      "> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .\n",
      "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 337 tokens\n",
      "> [query] Total LLM token usage: 337 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n",
      "> [query] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "response = query_engine.query(\"Which city has the highest population?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "5d9c29f0-3736-42d1-bfcd-e78ece8ca094",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Response(response=' The city with the highest population is Tokyo, with a population of 13,960,000.', source_nodes=[], metadata={'result': [('Tokyo', 13960000)], 'sql_query': 'SELECT city_name, population \\nFROM city_stats \\nORDER BY population DESC \\nLIMIT 1;'})"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
